- 
                Notifications
    You must be signed in to change notification settings 
- Fork 15k
[MLIR][NVVM] Update convert Ops to use builtin types #159704
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[MLIR][NVVM] Update convert Ops to use builtin types #159704
Conversation
This change updates the `convert.f32x2.to.f6x2`, `convert.f32x2.to.f8x2`, `convert.f16x2.to.f8x2`, and `convert.bf16x2.to.f8x2` Ops to use builtin types for the destination types as a `TypeAttr` instead of custom enums. The corresponding tests are updated to reflect the changes in the assembly format.
| @llvm/pr-subscribers-mlir-llvm @llvm/pr-subscribers-mlir Author: Srinivasa Ravi (Wolfram70) ChangesThis change updates the  The corresponding tests are updated to reflect the changes in the assembly format. Patch is 30.68 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/159704.diff 5 Files Affected: 
 diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 8537c7030aa8f..c540c5ccf50bf 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -21,6 +21,7 @@ include "mlir/Interfaces/SideEffectInterfaces.td"
 include "mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.td"
 include "mlir/Interfaces/InferIntRangeInterface.td"
 include "mlir/Dialect/LLVMIR/LLVMTypes.td"
+include "mlir/IR/CommonAttrConstraints.td"
 
 def LLVM_PointerGeneric : LLVM_PointerInAddressSpace<0>;
 def LLVM_PointerGlobal : LLVM_PointerInAddressSpace<1>;
@@ -1258,18 +1259,6 @@ def NVVM_ConvertFloatToTF32Op : NVVM_Op<"convert.float.to.tf32"> {
   }];
 }
 
-def ConvertFP6E2M3 : I32EnumAttrCase<"E2M3", 0, "e2m3">;
-def ConvertFP6E3M2 : I32EnumAttrCase<"E3M2", 1, "e3m2">;
-
-def ConvertFP6Type : I32EnumAttr<"ConvertFP6Type", "NVVM ConvertFP6Type kind",
-  [ConvertFP6E2M3, ConvertFP6E3M2]> {
-  let genSpecializedAttr = 0;
-  let cppNamespace = "::mlir::NVVM";
-}
-def ConvertFP6TypeAttr : EnumAttr<NVVM_Dialect, ConvertFP6Type, "convert_fp6_type"> {
-  let assemblyFormat = "`<` $value `>`";
-}
-
 def NVVM_ConvertF32x2ToF6x2Op : NVVM_Op<"convert.f32x2.to.f6x2"> {
   let summary = "Convert a pair of float inputs to f6x2";
   let description = [{
@@ -1290,19 +1279,20 @@ def NVVM_ConvertF32x2ToF6x2Op : NVVM_Op<"convert.f32x2.to.f6x2"> {
 
   let results = (outs AnyTypeOf<[I16, VectorOfLengthAndType<[2], [I8]>]>:$dst);
   let arguments = (ins 
-    ConvertFP6TypeAttr:$type,
     F32:$a,
     F32:$b,
-    DefaultValuedAttr<BoolAttr, "false">:$relu);
-  let assemblyFormat = "$type $a `,` $b attr-dict `:` type($dst)";
+    DefaultValuedAttr<BoolAttr, "false">:$relu,
+    TypeAttr:$dstTy);
+  let assemblyFormat = "$a `,` $b attr-dict `:` type($dst) `(` $dstTy `)`";
+  let hasVerifier = 1;
   
   let extraClassDeclaration = [{
-    static llvm::Intrinsic::ID getIntrinsicID(NVVM::ConvertFP6Type,
+    static llvm::Intrinsic::ID getIntrinsicID(mlir::Type dstTy,
                                               bool hasRelu);
   }];
 
   string llvmBuilder = [{
-    auto intId = NVVM::ConvertF32x2ToF6x2Op::getIntrinsicID($type, $relu);
+    auto intId = NVVM::ConvertF32x2ToF6x2Op::getIntrinsicID($dstTy, $relu);
     llvm::Value *packedI16 = createIntrinsicCall(builder, intId, {$a, $b});
     if(op.getDst().getType().isInteger(16))
       $dst = packedI16;
@@ -1312,19 +1302,6 @@ def NVVM_ConvertF32x2ToF6x2Op : NVVM_Op<"convert.f32x2.to.f6x2"> {
   }];
 }
 
-def ConvertFP8E4M3 : I32EnumAttrCase<"E4M3", 0, "e4m3">;
-def ConvertFP8E5M2 : I32EnumAttrCase<"E5M2", 1, "e5m2">;
-def ConvertFP8UE8M0 : I32EnumAttrCase<"UE8M0", 2, "ue8m0">;
-
-def ConvertFP8Type : I32EnumAttr<"ConvertFP8Type", "NVVM ConvertFP8Type kind",
-  [ConvertFP8E4M3, ConvertFP8E5M2, ConvertFP8UE8M0]> {
-  let genSpecializedAttr = 0;
-  let cppNamespace = "::mlir::NVVM";
-}
-def ConvertFP8TypeAttr : EnumAttr<NVVM_Dialect, ConvertFP8Type, "convert_fp8_type"> {
-  let assemblyFormat = "`<` $value `>`";
-}
-
 def NVVM_ConvertF32x2ToF8x2Op : NVVM_Op<"convert.f32x2.to.f8x2"> {
   let summary = "Convert a pair of float inputs to f8x2";
   let description = [{
@@ -1346,23 +1323,23 @@ def NVVM_ConvertF32x2ToF8x2Op : NVVM_Op<"convert.f32x2.to.f8x2"> {
   let hasVerifier = 1;
   let results = (outs AnyTypeOf<[I16, VectorOfLengthAndType<[2], [I8]>]>:$dst);
   let arguments = (ins
-    ConvertFP8TypeAttr:$type,
     F32:$a,
     F32:$b,
     DefaultValuedAttr<FPRoundingModeAttr, "FPRoundingMode::NONE">:$rnd,
     DefaultValuedAttr<SaturationModeAttr, "SaturationMode::NONE">:$sat,
-    DefaultValuedAttr<BoolAttr, "false">:$relu);
-  let assemblyFormat = "$type $a `,` $b attr-dict `:` type($dst)";
+    DefaultValuedAttr<BoolAttr, "false">:$relu,
+    TypeAttr:$dstTy);
+  let assemblyFormat = "$a `,` $b attr-dict `:` type($dst) `(` $dstTy `)`";
 
   let extraClassDeclaration = [{
-    static llvm::Intrinsic::ID getIntrinsicID(NVVM::ConvertFP8Type to,
+    static llvm::Intrinsic::ID getIntrinsicID(mlir::Type dstTy,
                                               NVVM::FPRoundingMode rnd,
                                               NVVM::SaturationMode sat,
                                               bool hasRelu);
   }];
   
   string llvmBuilder = [{
-    auto intId = NVVM::ConvertF32x2ToF8x2Op::getIntrinsicID($type, $rnd, $sat, $relu);
+    auto intId = NVVM::ConvertF32x2ToF8x2Op::getIntrinsicID($dstTy, $rnd, $sat, $relu);
     llvm::Value *packedI16 = createIntrinsicCall(builder, intId, {$a, $b});
     if(op.getDst().getType().isInteger(16))
       $dst = packedI16;
@@ -1394,18 +1371,18 @@ def NVVM_ConvertF16x2ToF8x2Op : NVVM_Op<"convert.f16x2.to.f8x2"> {
   let hasVerifier = 1;
   let results = (outs AnyTypeOf<[I16, VectorOfLengthAndType<[2], [I8]>]>:$dst);
   let arguments = (ins
-    ConvertFP8TypeAttr:$type,
     VectorOfLengthAndType<[2], [F16]>:$a,
-    DefaultValuedAttr<BoolAttr, "false">:$relu);
-  let assemblyFormat = "$type $a attr-dict `:` type($a) `->` type($dst)";
+    DefaultValuedAttr<BoolAttr, "false">:$relu,
+    TypeAttr:$dstTy);
+  let assemblyFormat = "$a attr-dict `:` type($a) `->` type($dst) `(` $dstTy `)`";
 
   let extraClassDeclaration = [{
-    static llvm::Intrinsic::ID getIntrinsicID(NVVM::ConvertFP8Type to,
+    static llvm::Intrinsic::ID getIntrinsicID(mlir::Type dstTy,
                                               bool hasRelu);
   }];
 
   string llvmBuilder = [{
-    auto intId = NVVM::ConvertF16x2ToF8x2Op::getIntrinsicID($type, $relu);
+    auto intId = NVVM::ConvertF16x2ToF8x2Op::getIntrinsicID($dstTy, $relu);
     llvm::Value *packedI16 = createIntrinsicCall(builder, intId, {$a});
     if(op.getDst().getType().isInteger(16))
       $dst = packedI16;
@@ -1437,11 +1414,11 @@ def NVVM_ConvertBF16x2ToF8x2Op : NVVM_Op<"convert.bf16x2.to.f8x2"> {
   let hasVerifier = 1;
   let results = (outs AnyTypeOf<[I16, VectorOfLengthAndType<[2], [I8]>]>:$dst);
   let arguments = (ins
-    ConvertFP8TypeAttr:$type,
     VectorOfLengthAndType<[2], [BF16]>:$a,
     DefaultValuedAttr<FPRoundingModeAttr, "FPRoundingMode::NONE">:$rnd,
-    DefaultValuedAttr<SaturationModeAttr, "SaturationMode::NONE">:$sat);
-  let assemblyFormat = "$type $a attr-dict `:` type($a) `->` type($dst)";
+    DefaultValuedAttr<SaturationModeAttr, "SaturationMode::NONE">:$sat,
+    TypeAttr:$dstTy);
+  let assemblyFormat = "$a attr-dict `:` type($a) `->` type($dst) `(` $dstTy `)`";
   
   let extraClassDeclaration = [{
     static llvm::Intrinsic::ID getIntrinsicID(NVVM::FPRoundingMode rnd,
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 77ec1ebde3109..28fa3f2a098e0 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -189,6 +189,14 @@ LogicalResult ConvertFloatToTF32Op::verify() {
   return success();
 }
 
+LogicalResult ConvertF32x2ToF6x2Op::verify() {
+  if (!llvm::isa<mlir::Float6E2M3FNType, mlir::Float6E3M2FNType>(getDstTy())) {
+    return emitError("Only f6E2M3FN and f6E3M2FN types are supported for "
+                     "ConvertF32x2ToF6x2Op.");
+  }
+  return success();
+}
+
 LogicalResult ConvertF32x2ToF8x2Op::verify() {
   using RndMode = NVVM::FPRoundingMode;
   using SatMode = NVVM::SaturationMode;
@@ -200,41 +208,52 @@ LogicalResult ConvertF32x2ToF8x2Op::verify() {
 
   bool hasRelu = getRelu();
 
-  switch (getType()) {
-  case ConvertFP8Type::E4M3:
-  case ConvertFP8Type::E5M2:
-    if (!isRoundingModeRN)
-      return emitOpError("Only RN rounding mode is supported for conversions "
-                         "from f32x2 to .e4m3x2 or .e5m2x2 types");
-    if (!isSatFinite)
-      return emitOpError("Only SATFINITE saturation mode is supported for "
-                         "conversions from f32x2 to .e4m3x2 or .e5m2x2 types");
-    break;
-  case ConvertFP8Type::UE8M0:
-    if (!(isRoundingModeRZ || isRoundingModeRP))
-      return emitOpError("Only RZ or RP rounding modes are supported for "
-                         "conversions from f32x2 to .ue8m0x2 type");
-    if (hasRelu)
-      return emitOpError("relu not supported for conversions to .ue8m0x2 type");
-    break;
-  }
-  return success();
+  return llvm::TypeSwitch<mlir::Type, LogicalResult>(getDstTy())
+      .Case<mlir::Float8E4M3FNType, mlir::Float8E5M2Type>(
+          [&](mlir::Type) -> LogicalResult {
+            if (!isRoundingModeRN) {
+              return emitOpError(
+                  "Only RN rounding mode is supported for conversions from "
+                  "f32x2 to f8E4M3FNx2 or f8E5M2x2 types");
+            }
+            if (!isSatFinite) {
+              return emitOpError(
+                  "Only SATFINITE saturation mode is supported for conversions "
+                  "from f32x2 to f8E4M3FNx2 or f8E5M2x2 types");
+            }
+            return success();
+          })
+      .Case<mlir::Float8E8M0FNUType>([&](mlir::Type) -> LogicalResult {
+        if (!(isRoundingModeRZ || isRoundingModeRP)) {
+          return emitOpError("Only RZ or RP rounding modes are supported for "
+                             "conversions from f32x2 to f8E8M0FNUx2 type");
+        }
+        if (hasRelu) {
+          return emitOpError(
+              "relu not supported for conversions to f8E8M0FNUx2 type");
+        }
+        return success();
+      })
+      .Default([this](mlir::Type) {
+        return emitOpError("Only f8e4m3fn, f8e5m2, and f8e8m0fnu types are "
+                           "supported for conversions from f32x2 to f8x2");
+      });
 }
 
 LogicalResult ConvertF16x2ToF8x2Op::verify() {
-  if (getType() == ConvertFP8Type::UE8M0)
-    return emitOpError("Only .e4m3 or .e5m2 types are supported for "
+  if (!llvm::isa<mlir::Float8E4M3FNType, mlir::Float8E5M2Type>(getDstTy())) {
+    return emitOpError("Only f8E4M3FN or f8E5M2 types are supported for "
                        "conversions from f16x2 to f8x2.");
-
+  }
   return success();
 }
 
 LogicalResult ConvertBF16x2ToF8x2Op::verify() {
   using RndMode = NVVM::FPRoundingMode;
 
-  if (getType() != ConvertFP8Type::UE8M0)
-    return emitOpError(
-        "Only .ue8m0 type is supported for conversions from bf16x2 to f8x2.");
+  if (!llvm::isa<mlir::Float8E8M0FNUType>(getDstTy()))
+    return emitOpError("Only f8E8M0FNU type is supported for conversions from "
+                       "bf16x2 to f8x2.");
 
   auto rnd = getRnd();
   if (!(rnd == RndMode::RZ || rnd == RndMode::RP))
@@ -1714,15 +1733,19 @@ ConvertFloatToTF32Op::getIntrinsicID(NVVM::FPRoundingMode rnd,
   has_relu ? llvm::Intrinsic::nvvm_ff_to_##type##_rn_relu_satfinite            \
            : llvm::Intrinsic::nvvm_ff_to_##type##_rn_satfinite
 
-llvm::Intrinsic::ID
-ConvertF32x2ToF6x2Op::getIntrinsicID(NVVM::ConvertFP6Type type, bool hasRelu) {
-  switch (type) {
-  case NVVM::ConvertFP6Type::E2M3:
-    return GET_F32x2_TO_F6x2_ID(e2m3x2, hasRelu);
-  case NVVM::ConvertFP6Type::E3M2:
-    return GET_F32x2_TO_F6x2_ID(e3m2x2, hasRelu);
-  }
-  llvm_unreachable("Invalid conversion in ConvertF32x2ToF6x2Op");
+llvm::Intrinsic::ID ConvertF32x2ToF6x2Op::getIntrinsicID(mlir::Type dstTy,
+                                                         bool hasRelu) {
+  return llvm::TypeSwitch<mlir::Type, llvm::Intrinsic::ID>(dstTy)
+      .Case<mlir::Float6E2M3FNType>([&](mlir::Float6E2M3FNType) {
+        return GET_F32x2_TO_F6x2_ID(e2m3x2, hasRelu);
+      })
+      .Case<mlir::Float6E3M2FNType>([&](mlir::Float6E3M2FNType) {
+        return GET_F32x2_TO_F6x2_ID(e3m2x2, hasRelu);
+      })
+      .Default([](mlir::Type) {
+        llvm_unreachable("Invalid conversion in ConvertF32x2ToF6x2Op");
+        return llvm::Intrinsic::not_intrinsic;
+      });
 }
 
 #define GET_F32x2_TO_F8X2_US_ID(rnd, has_satf)                                 \
@@ -1734,41 +1757,50 @@ ConvertF32x2ToF6x2Op::getIntrinsicID(NVVM::ConvertFP6Type type, bool hasRelu) {
            : llvm::Intrinsic::nvvm_ff_to_##type##_rn
 
 llvm::Intrinsic::ID
-ConvertF32x2ToF8x2Op::getIntrinsicID(NVVM::ConvertFP8Type type,
-                                     NVVM::FPRoundingMode rnd,
+ConvertF32x2ToF8x2Op::getIntrinsicID(mlir::Type dstTy, NVVM::FPRoundingMode rnd,
                                      NVVM::SaturationMode sat, bool hasRelu) {
   bool hasSatFinite = (sat == NVVM::SaturationMode::SATFINITE);
   bool hasRoundingModeRZ = (rnd == NVVM::FPRoundingMode::RZ);
   bool hasRoundingModeRP = (rnd == NVVM::FPRoundingMode::RP);
 
-  switch (type) {
-  case NVVM::ConvertFP8Type::E4M3:
-    return GET_F32x2_TO_F8X2_S_ID(e4m3x2, hasRelu);
-  case NVVM::ConvertFP8Type::E5M2:
-    return GET_F32x2_TO_F8X2_S_ID(e5m2x2, hasRelu);
-  case NVVM::ConvertFP8Type::UE8M0:
-    if (hasRoundingModeRZ)
-      return GET_F32x2_TO_F8X2_US_ID(rz, hasSatFinite);
-    else if (hasRoundingModeRP)
-      return GET_F32x2_TO_F8X2_US_ID(rp, hasSatFinite);
-  }
-  llvm_unreachable("Invalid conversion in CvtFloatToF8x2Op");
+  return llvm::TypeSwitch<mlir::Type, llvm::Intrinsic::ID>(dstTy)
+      .Case<mlir::Float8E4M3FNType>([&](mlir::Float8E4M3FNType) {
+        return GET_F32x2_TO_F8X2_S_ID(e4m3x2, hasRelu);
+      })
+      .Case<mlir::Float8E5M2Type>([&](mlir::Float8E5M2Type) {
+        return GET_F32x2_TO_F8X2_S_ID(e5m2x2, hasRelu);
+      })
+      .Case<mlir::Float8E8M0FNUType>([&](mlir::Float8E8M0FNUType) {
+        if (hasRoundingModeRZ)
+          return GET_F32x2_TO_F8X2_US_ID(rz, hasSatFinite);
+        else if (hasRoundingModeRP)
+          return GET_F32x2_TO_F8X2_US_ID(rp, hasSatFinite);
+
+        llvm_unreachable("Invalid conversion in ConvertF32x2ToF8x2Op");
+      })
+      .Default([](mlir::Type) {
+        llvm_unreachable("Invalid conversion in ConvertF32x2ToF8x2Op");
+        return llvm::Intrinsic::not_intrinsic;
+      });
 }
 
 #define GET_F16x2_TO_F8X2_ID(type, has_relu)                                   \
   has_relu ? llvm::Intrinsic::nvvm_f16x2_to_##type##_rn_relu                   \
            : llvm::Intrinsic::nvvm_f16x2_to_##type##_rn
 
-llvm::Intrinsic::ID
-ConvertF16x2ToF8x2Op::getIntrinsicID(NVVM::ConvertFP8Type type, bool hasRelu) {
-  switch (type) {
-  case NVVM::ConvertFP8Type::E4M3:
-    return GET_F16x2_TO_F8X2_ID(e4m3x2, hasRelu);
-  case NVVM::ConvertFP8Type::E5M2:
-    return GET_F16x2_TO_F8X2_ID(e5m2x2, hasRelu);
-  default:
-    llvm_unreachable("Invalid ConvertFP8Type for CvtF16x2ToF8x2Op");
-  }
+llvm::Intrinsic::ID ConvertF16x2ToF8x2Op::getIntrinsicID(mlir::Type dstTy,
+                                                         bool hasRelu) {
+  return llvm::TypeSwitch<mlir::Type, llvm::Intrinsic::ID>(dstTy)
+      .Case<mlir::Float8E4M3FNType>([&](mlir::Float8E4M3FNType) {
+        return GET_F16x2_TO_F8X2_ID(e4m3x2, hasRelu);
+      })
+      .Case<mlir::Float8E5M2Type>([&](mlir::Float8E5M2Type) {
+        return GET_F16x2_TO_F8X2_ID(e5m2x2, hasRelu);
+      })
+      .Default([](mlir::Type) {
+        llvm_unreachable("Invalid conversion in ConvertF16x2ToF8x2Op");
+        return llvm::Intrinsic::not_intrinsic;
+      });
 }
 
 #define GET_BF16X2_TO_F8X2_ID(rnd, has_satf)                                   \
diff --git a/mlir/test/Target/LLVMIR/nvvm/convert_fp6x2.mlir b/mlir/test/Target/LLVMIR/nvvm/convert_fp6x2.mlir
index 04163b578aa02..99289923b58b1 100644
--- a/mlir/test/Target/LLVMIR/nvvm/convert_fp6x2.mlir
+++ b/mlir/test/Target/LLVMIR/nvvm/convert_fp6x2.mlir
@@ -3,9 +3,9 @@
 // CHECK-LABEL: @convert_f32x2_to_fp6x2_packed
 llvm.func @convert_f32x2_to_fp6x2_packed(%srcA : f32, %srcB : f32) {
   //CHECK: %{{.*}} = call i16 @llvm.nvvm.ff.to.e2m3x2.rn.satfinite(float %{{.*}}, float %{{.*}})
-  %res1 = nvvm.convert.f32x2.to.f6x2 <e2m3> %srcA, %srcB : i16
+  %res1 = nvvm.convert.f32x2.to.f6x2 %srcA, %srcB : i16 (f6E2M3FN)
   //CHECK: %{{.*}} = call i16 @llvm.nvvm.ff.to.e3m2x2.rn.satfinite(float %{{.*}}, float %{{.*}})
-  %res2 = nvvm.convert.f32x2.to.f6x2 <e3m2> %srcA, %srcB : i16
+  %res2 = nvvm.convert.f32x2.to.f6x2 %srcA, %srcB : i16 (f6E3M2FN)
   llvm.return
 }
 
@@ -13,9 +13,9 @@ llvm.func @convert_f32x2_to_fp6x2_packed(%srcA : f32, %srcB : f32) {
 llvm.func @convert_f32x2_to_fp6x2_vector(%srcA : f32, %srcB : f32) {
   //CHECK: %[[res0:.*]] = call i16 @llvm.nvvm.ff.to.e2m3x2.rn.satfinite(float %{{.*}}, float %{{.*}})
   //CHECK-NEXT: %{{.*}} = bitcast i16 %[[res0]] to <2 x i8>
-  %res1 = nvvm.convert.f32x2.to.f6x2 <e2m3> %srcA, %srcB : vector<2xi8>
+  %res1 = nvvm.convert.f32x2.to.f6x2 %srcA, %srcB : vector<2xi8> (f6E2M3FN)
   //CHECK: %[[res1:.*]] = call i16 @llvm.nvvm.ff.to.e3m2x2.rn.satfinite(float %{{.*}}, float %{{.*}})
   //CHECK-NEXT: %{{.*}} = bitcast i16 %[[res1]] to <2 x i8>
-  %res2 = nvvm.convert.f32x2.to.f6x2 <e3m2> %srcA, %srcB : vector<2xi8>
+  %res2 = nvvm.convert.f32x2.to.f6x2 %srcA, %srcB : vector<2xi8> (f6E3M2FN)
   llvm.return
 }
diff --git a/mlir/test/Target/LLVMIR/nvvm/convert_fp8x2.mlir b/mlir/test/Target/LLVMIR/nvvm/convert_fp8x2.mlir
index 4a15efb9e805c..de21826445afb 100644
--- a/mlir/test/Target/LLVMIR/nvvm/convert_fp8x2.mlir
+++ b/mlir/test/Target/LLVMIR/nvvm/convert_fp8x2.mlir
@@ -5,31 +5,31 @@
 // CHECK-LABEL: @convert_f32x2_to_f8x2_e4m3
 llvm.func @convert_f32x2_to_f8x2_e4m3(%srcA : f32, %srcB : f32) {
   // CHECK: %{{.*}} = call i16 @llvm.nvvm.ff.to.e4m3x2.rn(float %{{.*}}, float %{{.*}})
-  %res1 = nvvm.convert.f32x2.to.f8x2 <e4m3> %srcA, %srcB {rnd = #nvvm.fp_rnd_mode<rn>, sat = #nvvm.sat_mode<satfinite>} : i16
+  %res1 = nvvm.convert.f32x2.to.f8x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode<rn>, sat = #nvvm.sat_mode<satfinite>} : i16 (f8E4M3FN)
   // CHECK: %{{.*}} = call i16 @llvm.nvvm.ff.to.e4m3x2.rn.relu(float %{{.*}}, float %{{.*}})
-  %res2 = nvvm.convert.f32x2.to.f8x2 <e4m3> %srcA, %srcB {relu = true, rnd = #nvvm.fp_rnd_mode<rn>, sat = #nvvm.sat_mode<satfinite>} : i16
+  %res2 = nvvm.convert.f32x2.to.f8x2 %srcA, %srcB {relu = true, rnd = #nvvm.fp_rnd_mode<rn>, sat = #nvvm.sat_mode<satfinite>} : i16 (f8E4M3FN)
   llvm.return
 }
 
 // CHECK-LABEL: @convert_f32x2_to_f8x2_e5m2
 llvm.func @convert_f32x2_to_f8x2_e5m2(%srcA : f32, %srcB : f32) {
   // CHECK: %{{.*}} = call i16 @llvm.nvvm.ff.to.e5m2x2.rn(float %{{.*}}, float %{{.*}})
-  %res1 = nvvm.convert.f32x2.to.f8x2 <e5m2> %srcA, %srcB {rnd = #nvvm.fp_rnd_mode<rn>, sat = #nvvm.sat_mode<satfinite>} : i16
+  %res1 = nvvm.convert.f32x2.to.f8x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode<rn>, sat = #nvvm.sat_mode<satfinite>} : i16 (f8E5M2)
   // CHECK: %{{.*}} = call i16 @llvm.nvvm.ff.to.e5m2x2.rn.relu(float %{{.*}}, float %{{.*}})
-  %res2 = nvvm.convert.f32x2.to.f8x2 <e5m2> %srcA, %srcB {relu = true, rnd = #nvvm.fp_rnd_mode<rn>, sat = #nvvm.sat_mode<satfinite>} : i16
+  %res2 = nvvm.convert.f32x2.to.f8x2 %srcA, %srcB {relu = true, rnd = #nvvm.fp_rnd_mode<rn>, sat = #nvvm.sat_mode<satfinite>} : i16 (f8E5M2)
   llvm.return
 }
 
 // CHECK-LABEL: @convert_f32x2_to_f8x2_ue8m0
 llvm.func @convert_f32x2_to_f8x2_ue8m0(%srcA : f32, %srcB : f32) {
   // CHECK: %{{.*}} = call i16 @llvm.nvvm.ff.to.ue8m0x2.rz(float %{{.*}}, float %{{.*}})
-  %res1 = nvvm.convert.f32x2.to.f8x2 <ue8m0> %srcA, %srcB {rnd = #nvvm.fp_rnd_mode<rz>} : i16
+  %res1 = nvvm.convert.f32x2.to.f8x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode<rz>} : i16 (f8E8M0FNU)
   // CHECK: %{{.*}} = call i16 @llvm.nvvm.ff.to.ue8m0x2.rp(float %{{.*}}, float %{{.*}})
-  %res2 = nvvm.convert.f32x2.to.f8x2 <ue8m0> %srcA, %srcB {rnd = #nvvm.fp_rnd_mode<rp>} : i16
+  %res2 = nvvm.convert.f32x2.to.f8x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode<rp>} : i16 (f8E8M0FNU)
   // CHECK: %{{.*}} = call i16 @llvm.nvvm.ff.to.ue8m0x2.rz.satfinite(float %{{.*}}, float %{{.*}})
-  %res3 = nvvm.convert.f32x2.to.f8x2 <ue8m0> %srcA, %srcB {rnd = #nvvm.fp_rnd_mode<rz>, sat = #nvvm.sat_mode<satfinite>} : i16
+  %res3 = nvvm.convert.f32x2.to.f8x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode<rz>, sat = #nvvm.sat_mode<satfinite>} : i16 (f8E8M0FNU)
   // CHECK: %{{.*}} = call i16 @llvm.nvvm.ff.to.ue8m0x2.rp.satfinite(float %{{.*}}, float %{{.*}})
-  %res4 = nvvm.convert.f32x2.to.f8x2 <ue8m0> %srcA, %srcB {rnd = #nvvm.fp_rnd_mode<rp>, sat = #nvvm.sat_mode<satfinite>} : i16
+  %res4 = nvvm.convert.f32x2.to.f8x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode<rp>, sat = #nvvm.sat_mode<satfinite>} : i16 (f8E8M0FNU)
   llvm.return
 }
 
@@ -37,10 +37,10 @@ llvm.func @convert_f32x2_to_f8x2_ue8m0(%srcA : f32, %srcB : f32) {
 llvm.func @convert_f32x2_to_f8x2_vector_return(%srcA : f32, %srcB : f32) {
   // CHECK: %[[res1:.*]] = call i...
[truncated]
 | 
| return emitOpError("Only ") | ||
| << mlir::Float6E2M3FNType::get(ctx) << " and " | ||
| << mlir::Float6E3M2FNType::get(ctx) | ||
| << " types are supported for conversions from f32x2 to f6x2."; | 
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[Not for this PR]:
I am wondering if there is a way to specify this type constraint (line 195) in the Op itself (in tablegen).
That way, the default case checks can all happen within the td file.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
| LLVM Buildbot has detected a new failure on builder  Full details are available at: https://lab.llvm.org/buildbot/#/builders/116/builds/19479 Here is the relevant piece of the build log for the reference | 
This change updates the `convert.f32x2.to.f6x2`, `convert.f32x2.to.f8x2`, `convert.f16x2.to.f8x2`, and `convert.bf16x2.to.f8x2` Ops to use builtin types for the destination types as a `TypeAttr` instead of custom enums. The corresponding tests are updated to reflect the changes in the assembly format.
This change updates the `convert.f32x2.to.f6x2`, `convert.f32x2.to.f8x2`, `convert.f16x2.to.f8x2`, and `convert.bf16x2.to.f8x2` Ops to use builtin types for the destination types as a `TypeAttr` instead of custom enums. The corresponding tests are updated to reflect the changes in the assembly format.
This change updates the `convert.f32x2.to.f6x2`, `convert.f32x2.to.f8x2`, `convert.f16x2.to.f8x2`, and `convert.bf16x2.to.f8x2` Ops to use builtin types for the destination types as a `TypeAttr` instead of custom enums. The corresponding tests are updated to reflect the changes in the assembly format.
This change updates the `convert.f32x2.to.f6x2`, `convert.f32x2.to.f8x2`, `convert.f16x2.to.f8x2`, and `convert.bf16x2.to.f8x2` Ops to use builtin types for the destination types as a `TypeAttr` instead of custom enums. The corresponding tests are updated to reflect the changes in the assembly format.
This change updates the
convert.f32x2.to.f6x2,convert.f32x2.to.f8x2,convert.f16x2.to.f8x2, andconvert.bf16x2.to.f8x2Ops to use builtin types for the destination types as aTypeAttrinstead of custom enums.The corresponding tests are updated to reflect the changes in the assembly format.